{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Structured Generation from Images or Documents Using Vision Language Models\n",
    "\n",
    "We will be using the SmolVLM-Instruct model from HuggingFaceTB to extract structured information from documents. We will run the VLM using the Hugging Face Transformers library and the [Outlines library](https://github.com/dottxt-ai/outlines), which facilitates structured generation based on limiting token sampling probabilities. \n",
    "\n",
    "> This approach is based on a [Outlines tutorial](https://dottxt-ai.github.io/outlines/latest/cookbook/atomic_caption/).\n",
    "\n",
    "## Dependencies and imports\n",
    "\n",
    "First, let's install the necessary libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install accelerate outlines transformers torch flash-attn datasets sentencepiece"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's continue with importing the necessary libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import outlines\n",
    "import torch\n",
    "\n",
    "from datasets import load_dataset\n",
    "from outlines.models.transformers_vision import transformers_vision\n",
    "from transformers import AutoModelForImageTextToText, AutoProcessor\n",
    "from pydantic import BaseModel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialising our model\n",
    "\n",
    "We will start by initialising our model from [HuggingFaceTB/SmolVLM-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct). Outlines expects us to pass in a model class and processor class, so we will make this example a bit more generic by creating a function that returns those. Alternatively, you could look at the model and tokenizer config within the [Hub repo files](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct/tree/main), and import those classes directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some kwargs in processor config are unused and will not have any effect: image_seq_len. \n",
      "Some kwargs in processor config are unused and will not have any effect: image_seq_len. \n"
     ]
    }
   ],
   "source": [
    "model_name = \"HuggingFaceTB/SmolVLM-Instruct\"\n",
    "\n",
    "\n",
    "def get_model_and_processor_class(model_name: str):\n",
    "    model = AutoModelForImageTextToText.from_pretrained(model_name)\n",
    "    processor = AutoProcessor.from_pretrained(model_name)\n",
    "    classes = model.__class__, processor.__class__\n",
    "    del model, processor\n",
    "    return classes\n",
    "\n",
    "\n",
    "model_class, processor_class = get_model_and_processor_class(model_name)\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = \"mps\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "model = transformers_vision(\n",
    "    model_name,\n",
    "    model_class=model_class,\n",
    "    device=device,\n",
    "    model_kwargs={\"torch_dtype\": torch.bfloat16, \"device_map\": \"auto\"},\n",
    "    processor_kwargs={\"device\": device},\n",
    "    processor_class=processor_class,\n",
    ")"
   ]
  },
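  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you already know which classes the checkpoint uses, you can also import them directly instead of loading the model once just to read them off. A minimal sketch, assuming the repo's config still points to the Idefics3 classes, which is what the [Hub repo files](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct/tree/main) show at the time of writing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional alternative to get_model_and_processor_class: import the classes\n",
    "# directly. Assumption: the SmolVLM-Instruct repo config lists the Idefics3\n",
    "# architecture, as the Hub repo files show at the time of writing.\n",
    "from transformers import Idefics3ForConditionalGeneration, Idefics3Processor\n",
    "\n",
    "# model_class, processor_class = Idefics3ForConditionalGeneration, Idefics3Processor"
   ]
  },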
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Structured Generation\n",
    "\n",
    "Now, we are going to define a function that will define how the output of our model will be structured. We will be using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset), which contains a set of images along with questions and their chosen and rejected reponses. This is an okay dataset but we want to create additional text-image-to-text data on top of the images to get our own structured dataset, and potentially fine-tune our model on it. We will use the model to generate a caption, a question and a simple quality tag for the image. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ImageData(BaseModel):\n",
    "    quality: str\n",
    "    description: str\n",
    "    question: str\n",
    "\n",
    "structured_generator = outlines.generate.json(model, ImageData)"
   ]
  },
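  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the prompt only allows three quality tags, we could also encode that constraint in the schema itself, so Outlines guarantees one of the three values at sampling time. A minimal sketch of the stricter schema; `ConstrainedImageData` is a hypothetical name and is not used in the rest of this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Literal\n",
    "\n",
    "\n",
    "# Hypothetical stricter schema: Pydantic renders Literal as a JSON-schema enum,\n",
    "# so Outlines will only ever sample \"good\", \"okay\" or \"bad\" for quality.\n",
    "class ConstrainedImageData(BaseModel):\n",
    "    quality: Literal[\"good\", \"okay\", \"bad\"]\n",
    "    description: str\n",
    "    question: str\n",
    "\n",
    "\n",
    "# constrained_generator = outlines.generate.json(model, ConstrainedImageData)"
   ]
  },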
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's come up with an extraction prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"\"\"\n",
    "You are an image analysis assisant.\n",
    "\n",
    "Provide a quality tag, a description and a question.\n",
    "\n",
    "The quality can either be \"good\", \"okay\" or \"bad\".\n",
    "The question should be concise and objective.\n",
    "\n",
    "Return your response as a valid JSON object.\n",
    "\"\"\".strip()"
   ]
  },
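  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can preview what the chat template turns this prompt into; this is the same formatting step the extraction function below will use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the chat-formatted prompt (same formatting as used during extraction)\n",
    "preview_messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": prompt}],\n",
    "    },\n",
    "]\n",
    "print(model.processor.apply_chat_template(preview_messages, add_generation_prompt=True))"
   ]
  },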
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's load our image dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['ds_name', 'image', 'question', 'chosen', 'rejected', 'origin_dataset', 'origin_split', 'idx', 'image_path'],\n",
       "    num_rows: 10\n",
       "})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = load_dataset(\"openbmb/RLAIF-V-Dataset\", split=\"train[:10]\")\n",
    "dataset"
   ]
  },
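  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each row stores a PIL image in the `image` column; indexing into it renders the image inline in a notebook, which gives a quick look at what we are working with."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the first image of the dataset (renders inline in a notebook)\n",
    "dataset[0][\"image\"]"
   ]
  },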
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's define a function that will extract the structured information from the image. We will format the prompt using the `apply_chat_template` method and pass it to the model along with the image after that."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/davidberenstein/Documents/programming/huggingface/cookbook/.venv/lib/python3.11/site-packages/dill/_dill.py:414: PicklingWarning: Cannot locate reference to <class '__main__.ImageData'>.\n",
      "  StockPickler.save(self, obj, save_persistent_id)\n",
      "/Users/davidberenstein/Documents/programming/huggingface/cookbook/.venv/lib/python3.11/site-packages/dill/_dill.py:414: PicklingWarning: Cannot pickle <class '__main__.ImageData'>: __main__.ImageData has recursive self-references that trigger a RecursionError.\n",
      "  StockPickler.save(self, obj, save_persistent_id)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e1d431b922334b0297195415a11cf68a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/10 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['ds_name', 'image', 'question', 'chosen', 'rejected', 'origin_dataset', 'origin_split', 'idx', 'image_path', 'synthetic_question', 'synthetic_description', 'synthetic_quality'],\n",
       "    num_rows: 10\n",
       "})"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def extract(row):\n",
    "    messages = [\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": prompt}],\n",
    "        },\n",
    "    ]\n",
    "\n",
    "    formatted_prompt = model.processor.apply_chat_template(\n",
    "        messages, add_generation_prompt=True\n",
    "    )\n",
    "\n",
    "    result = structured_generator(formatted_prompt, [row[\"image\"]])\n",
    "    row['synthetic_question'] = result.question\n",
    "    row['synthetic_description'] = result.description\n",
    "    row['synthetic_quality'] = result.quality\n",
    "    return row\n",
    "\n",
    "\n",
    "dataset = dataset.map(lambda x: extract(x))\n",
    "dataset"
   ]
  },
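  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before pushing, we can spot-check the synthetic columns on a single row."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick sanity check on the first generated example\n",
    "sample = dataset[0]\n",
    "print(\"quality:\", sample[\"synthetic_quality\"])\n",
    "print(\"description:\", sample[\"synthetic_description\"])\n",
    "print(\"question:\", sample[\"synthetic_question\"])"
   ]
  },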
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's now push our new dataset to the Hub."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ab88b1b3bb1441498788bdc2c2b4cf30",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e5e359d02ede43959e92a9e5626f9ffd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/10 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9f7f07dad09f47c5a8dfdeba403845f6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e47600c765b64b55aa6f93e9cf5d077e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "README.md:   0%|          | 0.00/719 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/datasets/davidberenstein1957/structured-generation-information-extraction-vlms-openbmb-RLAIF-V-Dataset/commit/373d6a25e8301077773fc6a37899b1598cf6f8cd', commit_message='Upload dataset', commit_description='', oid='373d6a25e8301077773fc6a37899b1598cf6f8cd', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/davidberenstein1957/structured-generation-information-extraction-vlms-openbmb-RLAIF-V-Dataset', endpoint='https://huggingface.co', repo_type='dataset', repo_id='davidberenstein1957/structured-generation-information-extraction-vlms-openbmb-RLAIF-V-Dataset'), pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset.push_to_hub(\"davidberenstein1957/structured-generation-information-extraction-vlms-openbmb-RLAIF-V-Dataset\", split=\"train\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<iframe\n",
    "  src=\"https://huggingface.co/datasets/davidberenstein1957/structured-generation-information-extraction-vlms-openbmb-RLAIF-V-Dataset/embed/viewer/default/train?row=3\"\n",
    "  frameborder=\"0\"\n",
    "  width=\"100%\"\n",
    "  height=\"560px\"\n",
    "></iframe>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The results are not perfect, but they are a good starting point to continue exploring with different models and prompts!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "We've seen how to extract structured information from documents using a vision language model. We can use similar extractive methods to extract structured information from documents, using somehting like `pdf2image` to convert the document to images and running information extraction on each image pdf of the page.\n",
    "\n",
    "```python\n",
    "pdf_path = \"path/to/your/pdf/file.pdf\"\n",
    "pages = convert_from_path(pdf_path)\n",
    "for page in pages:\n",
    "    extract_objects = extract_objects(page, prompt)\n",
    "```\n",
    "\n",
    "## Next Steps\n",
    "\n",
    "- Take a look at the [Outlines](https://github.com/outlines-ai/outlines) library for more information on how to use it. Explore the different methods and parameters.\n",
    "- Explore extraction on your own usecase with your own model.\n",
    "- Use a different method of extracting structured information from documents."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
